{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#  Gated recurrent unit (GRU) RNNs\n",
    "\n",
    "This chapter requires some exposition. The GRU updates are fully implemented and the code appears to work properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "from mxnet import nd, autograd\n",
    "import numpy as np\n",
    "mx.random.seed(1)\n",
    "ctx = mx.gpu(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset: \"The Time Machine\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open(\"../data/nlp/timemachine.txt\") as f:\n",
    "    time_machine = f.read()\n",
    "time_machine = time_machine[:-38083]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numerical representations of characters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "character_list = list(set(time_machine))\n",
    "vocab_size = len(character_list)\n",
    "character_dict = {}\n",
    "for e, char in enumerate(character_list):\n",
    "    character_dict[char] = e\n",
    "time_numerical = [character_dict[char] for char in time_machine]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One-hot representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def one_hots(numerical_list, vocab_size=vocab_size):\n",
    "    result = nd.zeros((len(numerical_list), vocab_size), ctx=ctx)\n",
    "    for i, idx in enumerate(numerical_list):\n",
    "        result[i, idx] = 1.0\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def textify(embedding):\n",
    "    result = \"\"\n",
    "    indices = nd.argmax(embedding, axis=1).asnumpy()\n",
    "    for idx in indices:\n",
    "        result += character_list[int(idx)]\n",
    "    return result"
   ]
  },
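  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, encoding a snippet with `one_hots` and decoding it with `textify` should round-trip back to the original characters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# round-trip check: indices -> one-hots -> characters\n",
    "print(textify(one_hots(time_numerical[:40])))  # should match the first 40 characters of the book"
   ]
  },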
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the data for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "seq_length = 64\n",
    "# -1 here so we have enough characters for labels later\n",
    "num_samples = (len(time_numerical) - 1) // seq_length\n",
    "dataset = one_hots(time_numerical[:seq_length*num_samples]).reshape((num_samples, seq_length, vocab_size))\n",
    "num_batches = len(dataset) // batch_size\n",
    "train_data = dataset[:num_batches*batch_size].reshape((num_batches, batch_size, seq_length, vocab_size))\n",
    "# swap batch_size and seq_length axis to make later access easier\n",
    "train_data = nd.swapaxes(train_data, 1, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing our labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "labels = one_hots(time_numerical[1:seq_length*num_samples+1])\n",
    "train_label = labels.reshape((num_batches, batch_size, seq_length, vocab_size))\n",
    "train_label = nd.swapaxes(train_label, 1, 2)"
   ]
  },
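  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both tensors now have shape `(num_batches, seq_length, batch_size, vocab_size)`, and each label sequence is just its input sequence shifted ahead by one character. A quick check on the first sequence of the first batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(train_data.shape, train_label.shape)\n",
    "# the label text should be the input text shifted by one character\n",
    "print(textify(train_data[0, :, 0]))\n",
    "print(textify(train_label[0, :, 0]))"
   ]
  },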
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gated recurrent units (GRU) RNNs\n",
    "\n",
    "Similar to LSTM blocks, the GRU also has mechanisms to enable \"memorizing\" information for an extended number of time steps. However, it does so in a more expedient way:\n",
    "\n",
    "* We no longer keep a separate memory cell $c_t$. Instead, $h_{t-1}$ is added to a \"new content\" version of itself to give $h_{t}$.\n",
    "* The \"new content\" version is given by $g_t = \\text{tanh}(X_t W_{xh} + (r_t \\odot h_{t-1})W_{hh} + b_h ),$ and is analogous to $g_t$ in the LSTM tutorial.\n",
    "* Here, there is a reset gate $r_t$ which moderates the impact of $h_{t-1}$ on the \"new content\" version.\n",
    "* The input gate $i_t$ and forget gate $f_t$ are replaced by an single update gate $z_t$, which weighs the old and new content via $z_t$ and $(1 - z_t)$ respectively.\n",
    "* There is no output gate $o_t$; the weighted sum is what becomes $h_t$.\n",
    "\n",
    "We use the GRU block with the following transformations that map inputs to outputs across blocks at consecutive layers and consecutive time steps: $\\newcommand{\\xb}{\\mathbf{x}} \\newcommand{\\RR}{\\mathbb{R}}$\n",
    "\n",
    "$$z_t = \\sigma(X_t W_{xz} + h_{t-1} W_{hz} + b_z),$$\n",
    "\n",
    "$$r_t = \\sigma(X_t W_{xr} + h_{t-1} W_{hr} + b_r),$$\n",
    "\n",
    "$$g_t = \\text{tanh}(X_t W_{xh} + (r_t \\odot h_{t-1})W_{hh} + b_h ),$$\n",
    "\n",
    "$$h_t = z_t \\odot h_{t-1} + (1-z_t) \\odot g_t,$$\n",
    "\n",
    "where $\\sigma$ and $\\text{tanh}$ are as before in the LSTM case. \n",
    "\n",
    "Empirically, GRUs have similar performance to LSTMs, while requiring less parameters and forgoing an internal time state. Intuitively, GRUs have enough gates/state for long-term retention, but not too much, so that training and convergence remain fast and convex. See the work of Chung et al. [2014] (https://arxiv.org/abs/1412.3555)."
   ]
  },
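  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring these equations into a full network, here is a minimal one-step sketch on toy shapes. The sizes and the `t`-prefixed names below are made up purely for illustration; the real parameters are allocated in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one GRU step on a toy batch, following the four equations above\n",
    "toy_batch, toy_in, toy_hid = 2, 5, 4  # made-up sizes, for illustration only\n",
    "tX = nd.random_normal(shape=(toy_batch, toy_in), ctx=ctx)\n",
    "th = nd.random_normal(shape=(toy_batch, toy_hid), ctx=ctx)\n",
    "tWxz, tWxr, tWxh = [nd.random_normal(shape=(toy_in, toy_hid), ctx=ctx) * .01 for _ in range(3)]\n",
    "tWhz, tWhr, tWhh = [nd.random_normal(shape=(toy_hid, toy_hid), ctx=ctx) * .01 for _ in range(3)]\n",
    "tbz, tbr, tbh = [nd.zeros(toy_hid, ctx=ctx) for _ in range(3)]\n",
    "tz = nd.sigmoid(nd.dot(tX, tWxz) + nd.dot(th, tWhz) + tbz)    # update gate z_t\n",
    "tr = nd.sigmoid(nd.dot(tX, tWxr) + nd.dot(th, tWhr) + tbr)    # reset gate r_t\n",
    "tg = nd.tanh(nd.dot(tX, tWxh) + nd.dot(tr * th, tWhh) + tbh)  # new content g_t\n",
    "th_next = tz * th + (1 - tz) * tg                             # h_t blends old and new content\n",
    "print(th_next.shape)  # (2, 4): one hidden vector per toy example"
   ]
  },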
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Allocate parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_inputs = vocab_size\n",
    "num_hidden = 256\n",
    "num_outputs = vocab_size\n",
    "\n",
    "########################\n",
    "#  Weights connecting the inputs to the hidden layer\n",
    "########################\n",
    "Wxz = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01\n",
    "Wxr = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01\n",
    "Wxh = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01\n",
    "\n",
    "########################\n",
    "#  Recurrent weights connecting the hidden layer across time steps\n",
    "########################\n",
    "Whz = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01\n",
    "Whr = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01\n",
    "Whh = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01\n",
    "\n",
    "########################\n",
    "#  Bias vector for hidden layer\n",
    "########################\n",
    "bz = nd.random_normal(shape=num_hidden, ctx=ctx) * .01\n",
    "br = nd.random_normal(shape=num_hidden, ctx=ctx) * .01\n",
    "bh = nd.random_normal(shape=num_hidden, ctx=ctx) * .01\n",
    "\n",
    "########################\n",
    "# Weights to the output nodes\n",
    "########################\n",
    "Why = nd.random_normal(shape=(num_hidden,num_outputs), ctx=ctx) * .01\n",
    "by = nd.random_normal(shape=num_outputs, ctx=ctx) * .01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attach the gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "params = [Wxz, Wxr, Wxh, Whz, Whr, Whh, bz, br, bh, Why, by]\n",
    "\n",
    "for param in params:\n",
    "    param.attach_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax Activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def softmax(y_linear, temperature=1.0):\n",
    "    lin = (y_linear-nd.max(y_linear)) / temperature\n",
    "    exp = nd.exp(lin)\n",
    "    partition = nd.sum(exp, axis=0, exclude=True).reshape((-1,1))\n",
    "    return exp / partition"
   ]
  },
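  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `temperature` argument rescales the inputs before exponentiating: temperatures below 1 sharpen the distribution toward the largest score, while temperatures above 1 flatten it. A quick illustration on some arbitrary scores:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = nd.array([[1.0, 2.0, 3.0]], ctx=ctx)  # arbitrary example scores\n",
    "print(softmax(scores, temperature=1.0))  # fairly smooth distribution\n",
    "print(softmax(scores, temperature=.1))   # nearly all mass on the largest score"
   ]
  },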
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def gru_rnn(inputs, h, temperature=1.0):\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        z = nd.sigmoid(nd.dot(X, Wxz) + nd.dot(h, Whz) + bz)\n",
    "        r = nd.sigmoid(nd.dot(X, Wxr) + nd.dot(h, Whr) + br)\n",
    "        g = nd.tanh(nd.dot(X, Wxh) + nd.dot(r * h, Whh) + bh)\n",
    "        h = z * h + (1 - z) * g\n",
    "        \n",
    "        yhat_linear = nd.dot(h, Why) + by\n",
    "        yhat = softmax(yhat_linear, temperature=temperature) \n",
    "        outputs.append(yhat)\n",
    "    return (outputs, h)"
   ]
  },
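  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (with the parameters still at their random initial values), we can run one batch through `gru_rnn` and inspect the shapes: we expect one output per time step, each holding a distribution over the vocabulary for every sequence in the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs, h = gru_rnn(train_data[0], nd.zeros((batch_size, num_hidden), ctx=ctx))\n",
    "print(len(outputs), outputs[0].shape, h.shape)  # seq_length outputs, each (batch_size, vocab_size)"
   ]
  },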
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross-entropy loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cross_entropy(yhat, y):\n",
    "    return - nd.mean(nd.sum(y * nd.log(yhat), axis=0, exclude=True))"
   ]
  },
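  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick illustration with made-up predictions: a confident correct prediction incurs a tiny loss, while a confident wrong one is penalized heavily."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_true = nd.array([[0., 1.]], ctx=ctx)  # the true class is index 1\n",
    "print(cross_entropy(nd.array([[.01, .99]], ctx=ctx), y_true))  # ~0.01: confident and correct\n",
    "print(cross_entropy(nd.array([[.99, .01]], ctx=ctx), y_true))  # ~4.6: confident and wrong"
   ]
  },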
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Averaging the loss over the sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def average_ce_loss(outputs, labels):\n",
    "    assert(len(outputs) == len(labels))\n",
    "    total_loss = nd.array([0.], ctx=ctx)\n",
    "    for (output, label) in zip(outputs,labels):\n",
    "        total_loss = total_loss + cross_entropy(output, label)\n",
    "    return total_loss / len(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def SGD(params, lr):    \n",
    "    for param in params:\n",
    "        param[:] = param - lr * param.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating text by sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def sample(prefix, num_chars, temperature=1.0):\n",
    "    #####################################\n",
    "    # Initialize the string that we'll return to the supplied prefix\n",
    "    #####################################\n",
    "    string = prefix\n",
    "\n",
    "    #####################################\n",
    "    # Prepare the prefix as a sequence of one-hots for ingestion by RNN\n",
    "    #####################################\n",
    "    prefix_numerical = [character_dict[char] for char in prefix]\n",
    "    input_sequence = one_hots(prefix_numerical)\n",
    "    \n",
    "    #####################################\n",
    "    # Set the initial state of the hidden representation ($h_0$) to the zero vector\n",
    "    #####################################    \n",
    "    h = nd.zeros(shape=(1, num_hidden), ctx=ctx)\n",
    "    c = nd.zeros(shape=(1, num_hidden), ctx=ctx)\n",
    "\n",
    "    #####################################\n",
    "    # For num_chars iterations,\n",
    "    #     1) feed in the current input\n",
    "    #     2) sample next character from from output distribution\n",
    "    #     3) add sampled character to the decoded string\n",
    "    #     4) prepare the sampled character as a one_hot (to be the next input)\n",
    "    #####################################    \n",
    "    for i in range(num_chars):\n",
    "        outputs, h = gru_rnn(input_sequence, h, temperature=temperature)\n",
    "        choice = np.random.choice(vocab_size, p=outputs[-1][0].asnumpy())\n",
    "        string += character_list[choice]\n",
    "        input_sequence = one_hots([choice])\n",
    "    return string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "epochs = 2000\n",
    "moving_loss = 0.\n",
    "\n",
    "learning_rate = 2.0\n",
    "\n",
    "# state = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)\n",
    "for e in range(epochs):\n",
    "    ############################\n",
    "    # Attenuate the learning rate by a factor of 2 every 100 epochs.\n",
    "    ############################\n",
    "    if ((e+1) % 100 == 0):\n",
    "        learning_rate = learning_rate / 2.0\n",
    "    h = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)\n",
    "    for i in range(num_batches):\n",
    "        data_one_hot = train_data[i]\n",
    "        label_one_hot = train_label[i]\n",
    "        with autograd.record():\n",
    "            outputs, h = gru_rnn(data_one_hot, h)\n",
    "            loss = average_ce_loss(outputs, label_one_hot)\n",
    "            loss.backward()\n",
    "        SGD(params, learning_rate)\n",
    "\n",
    "        ##########################\n",
    "        #  Keep a moving average of the losses\n",
    "        ##########################\n",
    "        if (i == 0) and (e == 0):\n",
    "            moving_loss = nd.mean(loss).asscalar()\n",
    "        else:\n",
    "            moving_loss = .99 * moving_loss + .01 * nd.mean(loss).asscalar()\n",
    "      \n",
    "    print(\"Epoch %s. Loss: %s\" % (e, moving_loss)) \n",
    "    print(sample(\"The Time Ma\", 1024, temperature=.1))\n",
    "    print(sample(\"The Medical Man rose, came to the lamp,\", 1024, temperature=.1))\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n",
    "[Placeholder]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Simple, LSTM, and GRU RNNs with gluon](../chapter05_recurrent-neural-networks/rnns-gluon.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
